{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nltk\n",
    "import sklearn_crfsuite\n",
    "\n",
    "from copy import deepcopy\n",
    "from collections import defaultdict\n",
    "\n",
    "from sklearn_crfsuite.metrics import flat_classification_report\n",
    "\n",
    "from ner_evaluation.ner_eval import collect_named_entities\n",
    "from ner_evaluation.ner_eval import compute_metrics\n",
    "from ner_evaluation.ner_eval import compute_precision_recall_wrapper"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train a CRF on the CoNLL 2002 NER Spanish data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "nltk.corpus.conll2002.fileids()\n",
    "train_sents = list(nltk.corpus.conll2002.iob_sents('esp.train'))\n",
    "test_sents = list(nltk.corpus.conll2002.iob_sents('esp.testb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def word2features(sent, i):\n",
    "    \"\"\"Build the feature dict for the token at position ``i`` of ``sent``.\n",
    "\n",
    "    Every item of ``sent`` is read as ``(word, postag, ...)``. The dict holds\n",
    "    features of the token itself followed by the same kind of features for its\n",
    "    left (``-1:``) and right (``+1:``) neighbour; where a neighbour is missing,\n",
    "    ``'BOS'`` / ``'EOS'`` is set to ``True`` instead.\n",
    "    \"\"\"\n",
    "    token, tag = sent[i][0], sent[i][1]\n",
    "\n",
    "    feats = {\n",
    "        'bias': 1.0,\n",
    "        'word.lower()': token.lower(),\n",
    "        'word[-3:]': token[-3:],\n",
    "        'word[-2:]': token[-2:],\n",
    "        'word.isupper()': token.isupper(),\n",
    "        'word.istitle()': token.istitle(),\n",
    "        'word.isdigit()': token.isdigit(),\n",
    "        'postag': tag,\n",
    "        'postag[:2]': tag[:2],\n",
    "    }\n",
    "\n",
    "    # (neighbour exists?, neighbour index, feature-name prefix, boundary flag)\n",
    "    neighbours = (\n",
    "        (i > 0, i - 1, '-1:', 'BOS'),\n",
    "        (i < len(sent) - 1, i + 1, '+1:', 'EOS'),\n",
    "    )\n",
    "    for exists, j, prefix, boundary in neighbours:\n",
    "        if not exists:\n",
    "            feats[boundary] = True\n",
    "            continue\n",
    "        n_token, n_tag = sent[j][0], sent[j][1]\n",
    "        feats[prefix + 'word.lower()'] = n_token.lower()\n",
    "        feats[prefix + 'word.istitle()'] = n_token.istitle()\n",
    "        feats[prefix + 'word.isupper()'] = n_token.isupper()\n",
    "        feats[prefix + 'postag'] = n_tag\n",
    "        feats[prefix + 'postag[:2]'] = n_tag[:2]\n",
    "\n",
    "    return feats\n",
    "\n",
    "\n",
    "def sent2features(sent):\n",
    "    \"\"\"Return one feature dict per token of ``sent``, in sentence order.\"\"\"\n",
    "    return [word2features(sent, idx) for idx, _ in enumerate(sent)]\n",
    "\n",
    "\n",
    "def sent2labels(sent):\n",
    "    \"\"\"Return the label of every ``(token, postag, label)`` triple in ``sent``.\"\"\"\n",
    "    labels = []\n",
    "    for _token, _postag, label in sent:\n",
    "        labels.append(label)\n",
    "    return labels\n",
    "\n",
    "\n",
    "def sent2tokens(sent):\n",
    "    \"\"\"Return the token of every ``(token, postag, label)`` triple in ``sent``.\"\"\"\n",
    "    tokens = []\n",
    "    for token, _postag, _label in sent:\n",
    "        tokens.append(token)\n",
    "    return tokens"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Feature Extraction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 761 ms, sys: 48.4 ms, total: 809 ms\n",
      "Wall time: 809 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "X_train = [sent2features(s) for s in train_sents]\n",
    "y_train = [sent2labels(s) for s in train_sents]\n",
    "\n",
    "X_test = [sent2features(s) for s in test_sents]\n",
    "y_test = [sent2labels(s) for s in test_sents]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 28.7 s, sys: 36.4 ms, total: 28.7 s\n",
      "Wall time: 28.7 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "crf = sklearn_crfsuite.CRF(\n",
    "    algorithm='lbfgs',\n",
    "    c1=0.1,\n",
    "    c2=0.1,\n",
    "    max_iterations=100,\n",
    "    all_possible_transitions=True\n",
    ")\n",
    "crf.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Performance per label type per token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "       B-LOC      0.810     0.784     0.797      1084\n",
      "       I-LOC      0.690     0.637     0.662       325\n",
      "      B-MISC      0.731     0.569     0.640       339\n",
      "      I-MISC      0.699     0.589     0.639       557\n",
      "       B-ORG      0.807     0.832     0.820      1400\n",
      "       I-ORG      0.852     0.786     0.818      1104\n",
      "       B-PER      0.850     0.884     0.867       735\n",
      "       I-PER      0.893     0.943     0.917       634\n",
      "\n",
      "   micro avg      0.813     0.787     0.799      6178\n",
      "   macro avg      0.791     0.753     0.770      6178\n",
      "weighted avg      0.809     0.787     0.796      6178\n",
      "\n"
     ]
    }
   ],
   "source": [
    "y_pred = crf.predict(X_test)\n",
    "\n",
    "# leave the 'O' label out of the evaluation; filtering (rather than list.remove)\n",
    "# does not raise if the fitted model has no 'O' class\n",
    "labels = [label for label in crf.classes_ if label != 'O']\n",
    "\n",
    "# sort by entity type first and B/I prefix second, so B-X and I-X rows are grouped\n",
    "sorted_labels = sorted(labels, key=lambda name: (name[1:], name[0]))\n",
    "print(flat_classification_report(y_test, y_pred, labels=sorted_labels, digits=3))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Performance over full named entities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# gold label sequence of every test sentence, built with the sent2labels helper\n",
    "# instead of a loop that rebinds its own iteration variable\n",
    "test_sents_labels = [sent2labels(sentence) for sentence in test_sents]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "index = 2\n",
    "true = collect_named_entities(test_sents_labels[index])\n",
    "pred = collect_named_entities(y_pred[index])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Entity(e_type='MISC', start_offset=12, end_offset=12),\n",
       " Entity(e_type='LOC', start_offset=15, end_offset=15),\n",
       " Entity(e_type='PER', start_offset=37, end_offset=39),\n",
       " Entity(e_type='ORG', start_offset=45, end_offset=46)]"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "true"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Entity(e_type='MISC', start_offset=12, end_offset=12),\n",
       " Entity(e_type='LOC', start_offset=15, end_offset=15),\n",
       " Entity(e_type='PER', start_offset=37, end_offset=39),\n",
       " Entity(e_type='LOC', start_offset=45, end_offset=46)]"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "index = 2\n",
    "true = collect_named_entities(test_sents_labels[index])\n",
    "pred = collect_named_entities(y_pred[index])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Entity(e_type='MISC', start_offset=12, end_offset=12),\n",
       " Entity(e_type='LOC', start_offset=15, end_offset=15),\n",
       " Entity(e_type='PER', start_offset=37, end_offset=39),\n",
       " Entity(e_type='ORG', start_offset=45, end_offset=46)]"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "true"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Entity(e_type='MISC', start_offset=12, end_offset=12),\n",
       " Entity(e_type='LOC', start_offset=15, end_offset=15),\n",
       " Entity(e_type='PER', start_offset=37, end_offset=39),\n",
       " Entity(e_type='LOC', start_offset=45, end_offset=46)]"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "({'strict': {'correct': 3,\n",
       "   'incorrect': 1,\n",
       "   'partial': 0,\n",
       "   'missed': 0,\n",
       "   'spurious': 0,\n",
       "   'precision': 0,\n",
       "   'recall': 0,\n",
       "   'actual': 4,\n",
       "   'possible': 4},\n",
       "  'ent_type': {'correct': 3,\n",
       "   'incorrect': 1,\n",
       "   'partial': 0,\n",
       "   'missed': 0,\n",
       "   'spurious': 0,\n",
       "   'precision': 0,\n",
       "   'recall': 0,\n",
       "   'actual': 4,\n",
       "   'possible': 4},\n",
       "  'partial': {'correct': 4,\n",
       "   'incorrect': 0,\n",
       "   'partial': 0,\n",
       "   'missed': 0,\n",
       "   'spurious': 0,\n",
       "   'precision': 0,\n",
       "   'recall': 0,\n",
       "   'actual': 4,\n",
       "   'possible': 4},\n",
       "  'exact': {'correct': 4,\n",
       "   'incorrect': 0,\n",
       "   'partial': 0,\n",
       "   'missed': 0,\n",
       "   'spurious': 0,\n",
       "   'precision': 0,\n",
       "   'recall': 0,\n",
       "   'actual': 4,\n",
       "   'possible': 4}},\n",
       " {'LOC': {'strict': {'correct': 1,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 0,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 1},\n",
       "   'ent_type': {'correct': 1,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 0,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 1},\n",
       "   'partial': {'correct': 1,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 0,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 1},\n",
       "   'exact': {'correct': 1,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 0,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 1}},\n",
       "  'MISC': {'strict': {'correct': 1,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 0,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 1},\n",
       "   'ent_type': {'correct': 1,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 0,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 1},\n",
       "   'partial': {'correct': 1,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 0,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 1},\n",
       "   'exact': {'correct': 1,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 0,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 1}},\n",
       "  'PER': {'strict': {'correct': 1,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 0,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 1},\n",
       "   'ent_type': {'correct': 1,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 0,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 1},\n",
       "   'partial': {'correct': 1,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 0,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 1},\n",
       "   'exact': {'correct': 1,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 0,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 1}},\n",
       "  'ORG': {'strict': {'correct': 0,\n",
       "    'incorrect': 1,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 0,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 1},\n",
       "   'ent_type': {'correct': 0,\n",
       "    'incorrect': 1,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 0,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 1},\n",
       "   'partial': {'correct': 1,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 0,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 1},\n",
       "   'exact': {'correct': 1,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 0,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 1}}})"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "compute_metrics(true, pred, ['LOC', 'MISC', 'PER', 'ORG'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "to_test = [2,4,12,14]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "index = 2\n",
    "true_named_entities_type = defaultdict(list)\n",
    "pred_named_entities_type = defaultdict(list)\n",
    "\n",
    "# Group the entities of one sentence by entity type. The loop variables are not\n",
    "# called `true` / `pred`, so the entity lists bound to those names by earlier\n",
    "# cells are not overwritten with a single Entity.\n",
    "for true_entity in collect_named_entities(test_sents_labels[index]):\n",
    "    true_named_entities_type[true_entity.e_type].append(true_entity)\n",
    "\n",
    "for pred_entity in collect_named_entities(y_pred[index]):\n",
    "    pred_named_entities_type[pred_entity.e_type].append(pred_entity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "defaultdict(list,\n",
       "            {'MISC': [Entity(e_type='MISC', start_offset=12, end_offset=12)],\n",
       "             'LOC': [Entity(e_type='LOC', start_offset=15, end_offset=15)],\n",
       "             'PER': [Entity(e_type='PER', start_offset=37, end_offset=39)],\n",
       "             'ORG': [Entity(e_type='ORG', start_offset=45, end_offset=46)]})"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "true_named_entities_type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "defaultdict(list,\n",
       "            {'MISC': [Entity(e_type='MISC', start_offset=12, end_offset=12)],\n",
       "             'LOC': [Entity(e_type='LOC', start_offset=15, end_offset=15),\n",
       "              Entity(e_type='LOC', start_offset=45, end_offset=46)],\n",
       "             'PER': [Entity(e_type='PER', start_offset=37, end_offset=39)]})"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_named_entities_type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Entity(e_type='LOC', start_offset=15, end_offset=15)]"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "true_named_entities_type['LOC']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Entity(e_type='LOC', start_offset=15, end_offset=15),\n",
       " Entity(e_type='LOC', start_offset=45, end_offset=46)]"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_named_entities_type['LOC']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "({'strict': {'correct': 1,\n",
       "   'incorrect': 0,\n",
       "   'partial': 0,\n",
       "   'missed': 0,\n",
       "   'spurious': 1,\n",
       "   'precision': 0,\n",
       "   'recall': 0,\n",
       "   'actual': 2,\n",
       "   'possible': 1},\n",
       "  'ent_type': {'correct': 1,\n",
       "   'incorrect': 0,\n",
       "   'partial': 0,\n",
       "   'missed': 0,\n",
       "   'spurious': 1,\n",
       "   'precision': 0,\n",
       "   'recall': 0,\n",
       "   'actual': 2,\n",
       "   'possible': 1},\n",
       "  'partial': {'correct': 1,\n",
       "   'incorrect': 0,\n",
       "   'partial': 0,\n",
       "   'missed': 0,\n",
       "   'spurious': 1,\n",
       "   'precision': 0,\n",
       "   'recall': 0,\n",
       "   'actual': 2,\n",
       "   'possible': 1},\n",
       "  'exact': {'correct': 1,\n",
       "   'incorrect': 0,\n",
       "   'partial': 0,\n",
       "   'missed': 0,\n",
       "   'spurious': 1,\n",
       "   'precision': 0,\n",
       "   'recall': 0,\n",
       "   'actual': 2,\n",
       "   'possible': 1}},\n",
       " {'LOC': {'strict': {'correct': 1,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 1,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 2,\n",
       "    'possible': 1},\n",
       "   'ent_type': {'correct': 1,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 1,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 2,\n",
       "    'possible': 1},\n",
       "   'partial': {'correct': 1,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 1,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 2,\n",
       "    'possible': 1},\n",
       "   'exact': {'correct': 1,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 1,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 2,\n",
       "    'possible': 1}},\n",
       "  'MISC': {'strict': {'correct': 0,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 1,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 0},\n",
       "   'ent_type': {'correct': 0,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 1,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 0},\n",
       "   'partial': {'correct': 0,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 1,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 0},\n",
       "   'exact': {'correct': 0,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 1,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 0}},\n",
       "  'PER': {'strict': {'correct': 0,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 1,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 0},\n",
       "   'ent_type': {'correct': 0,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 1,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 0},\n",
       "   'partial': {'correct': 0,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 1,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 0},\n",
       "   'exact': {'correct': 0,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 1,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 0}},\n",
       "  'ORG': {'strict': {'correct': 0,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 1,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 0},\n",
       "   'ent_type': {'correct': 0,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 1,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 0},\n",
       "   'partial': {'correct': 0,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 1,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 0},\n",
       "   'exact': {'correct': 0,\n",
       "    'incorrect': 0,\n",
       "    'partial': 0,\n",
       "    'missed': 0,\n",
       "    'spurious': 1,\n",
       "    'precision': 0,\n",
       "    'recall': 0,\n",
       "    'actual': 1,\n",
       "    'possible': 0}}})"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "compute_metrics(true_named_entities_type['LOC'], pred_named_entities_type['LOC'], ['LOC', 'MISC', 'PER', 'ORG'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results over all test sentences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# entity types to evaluate, defined once and reused below\n",
    "entity_types = ['PER', 'LOC', 'MISC', 'ORG']\n",
    "\n",
    "metrics_results = {'correct': 0, 'incorrect': 0, 'partial': 0,\n",
    "                   'missed': 0, 'spurious': 0, 'possible': 0, 'actual': 0, 'precision': 0, 'recall': 0}\n",
    "\n",
    "# precision and recall are recomputed from the counts after the loop,\n",
    "# so only the raw counts are accumulated sentence by sentence\n",
    "count_metrics = [m for m in metrics_results if m not in ('precision', 'recall')]\n",
    "\n",
    "# overall results: one set of counters per evaluation schema\n",
    "results = {'strict': deepcopy(metrics_results),\n",
    "           'ent_type': deepcopy(metrics_results),\n",
    "           'partial': deepcopy(metrics_results),\n",
    "           'exact': deepcopy(metrics_results)}\n",
    "\n",
    "# results aggregated by entity type\n",
    "evaluation_agg_entities_type = {e: deepcopy(results) for e in entity_types}\n",
    "\n",
    "for true_ents, pred_ents in zip(test_sents_labels, y_pred):\n",
    "\n",
    "    # compute results for one sentence\n",
    "    tmp_results, tmp_agg_results = compute_metrics(\n",
    "        collect_named_entities(true_ents), collect_named_entities(pred_ents), entity_types\n",
    "    )\n",
    "\n",
    "    # aggregate overall counts\n",
    "    for eval_schema in results:\n",
    "        for metric in count_metrics:\n",
    "            results[eval_schema][metric] += tmp_results[eval_schema][metric]\n",
    "\n",
    "    # aggregate counts by entity type\n",
    "    for e_type in entity_types:\n",
    "        for eval_schema in tmp_agg_results[e_type]:\n",
    "            for metric in count_metrics:\n",
    "                evaluation_agg_entities_type[e_type][eval_schema][metric] += tmp_agg_results[e_type][eval_schema][metric]\n",
    "\n",
    "# global precision and recall, computed once from the summed counts\n",
    "results = compute_precision_recall_wrapper(results)\n",
    "\n",
    "# precision and recall at the individual entity-type level\n",
    "for e_type in entity_types:\n",
    "    evaluation_agg_entities_type[e_type] = compute_precision_recall_wrapper(evaluation_agg_entities_type[e_type])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'ent_type': {'correct': 2860,\n",
       "  'incorrect': 523,\n",
       "  'partial': 0,\n",
       "  'missed': 176,\n",
       "  'spurious': 139,\n",
       "  'possible': 3559,\n",
       "  'actual': 3522,\n",
       "  'precision': 0.8120386144236229,\n",
       "  'recall': 0.8035965158752458},\n",
       " 'partial': {'correct': 3278,\n",
       "  'incorrect': 0,\n",
       "  'partial': 105,\n",
       "  'missed': 176,\n",
       "  'spurious': 139,\n",
       "  'possible': 3559,\n",
       "  'actual': 3522,\n",
       "  'precision': 0.9456274843838728,\n",
       "  'recall': 0.9357965720708064},\n",
       " 'strict': {'correct': 2783,\n",
       "  'incorrect': 600,\n",
       "  'partial': 0,\n",
       "  'missed': 176,\n",
       "  'spurious': 139,\n",
       "  'possible': 3559,\n",
       "  'actual': 3522,\n",
       "  'precision': 0.7901760363429869,\n",
       "  'recall': 0.78196122506322},\n",
       " 'exact': {'correct': 3278,\n",
       "  'incorrect': 105,\n",
       "  'partial': 0,\n",
       "  'missed': 176,\n",
       "  'spurious': 139,\n",
       "  'possible': 3559,\n",
       "  'actual': 3522,\n",
       "  'precision': 0.9307211811470755,\n",
       "  'recall': 0.9210452374262433}}"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'PER': {'ent_type': {'correct': 651,\n",
       "   'incorrect': 67,\n",
       "   'partial': 0,\n",
       "   'missed': 17,\n",
       "   'spurious': 139,\n",
       "   'possible': 735,\n",
       "   'actual': 857,\n",
       "   'precision': 0.7596266044340724,\n",
       "   'recall': 0.8857142857142857},\n",
       "  'partial': {'correct': 711,\n",
       "   'incorrect': 0,\n",
       "   'partial': 7,\n",
       "   'missed': 17,\n",
       "   'spurious': 139,\n",
       "   'possible': 735,\n",
       "   'actual': 857,\n",
       "   'precision': 0.8337222870478413,\n",
       "   'recall': 0.972108843537415},\n",
       "  'strict': {'correct': 646,\n",
       "   'incorrect': 72,\n",
       "   'partial': 0,\n",
       "   'missed': 17,\n",
       "   'spurious': 139,\n",
       "   'possible': 735,\n",
       "   'actual': 857,\n",
       "   'precision': 0.7537922987164527,\n",
       "   'recall': 0.8789115646258503},\n",
       "  'exact': {'correct': 711,\n",
       "   'incorrect': 7,\n",
       "   'partial': 0,\n",
       "   'missed': 17,\n",
       "   'spurious': 139,\n",
       "   'possible': 735,\n",
       "   'actual': 857,\n",
       "   'precision': 0.8296382730455076,\n",
       "   'recall': 0.9673469387755103}},\n",
       " 'LOC': {'ent_type': {'correct': 855,\n",
       "   'incorrect': 180,\n",
       "   'partial': 0,\n",
       "   'missed': 49,\n",
       "   'spurious': 139,\n",
       "   'possible': 1084,\n",
       "   'actual': 1174,\n",
       "   'precision': 0.7282793867120954,\n",
       "   'recall': 0.7887453874538746},\n",
       "  'partial': {'correct': 1016,\n",
       "   'incorrect': 0,\n",
       "   'partial': 19,\n",
       "   'missed': 49,\n",
       "   'spurious': 139,\n",
       "   'possible': 1084,\n",
       "   'actual': 1174,\n",
       "   'precision': 0.8735093696763203,\n",
       "   'recall': 0.9460332103321033},\n",
       "  'strict': {'correct': 844,\n",
       "   'incorrect': 191,\n",
       "   'partial': 0,\n",
       "   'missed': 49,\n",
       "   'spurious': 139,\n",
       "   'possible': 1084,\n",
       "   'actual': 1174,\n",
       "   'precision': 0.7189097103918228,\n",
       "   'recall': 0.7785977859778598},\n",
       "  'exact': {'correct': 1016,\n",
       "   'incorrect': 19,\n",
       "   'partial': 0,\n",
       "   'missed': 49,\n",
       "   'spurious': 139,\n",
       "   'possible': 1084,\n",
       "   'actual': 1174,\n",
       "   'precision': 0.8654173764906303,\n",
       "   'recall': 0.9372693726937269}},\n",
       " 'MISC': {'ent_type': {'correct': 200,\n",
       "   'incorrect': 89,\n",
       "   'partial': 0,\n",
       "   'missed': 51,\n",
       "   'spurious': 139,\n",
       "   'possible': 340,\n",
       "   'actual': 428,\n",
       "   'precision': 0.4672897196261682,\n",
       "   'recall': 0.5882352941176471},\n",
       "  'partial': {'correct': 257,\n",
       "   'incorrect': 0,\n",
       "   'partial': 32,\n",
       "   'missed': 51,\n",
       "   'spurious': 139,\n",
       "   'possible': 340,\n",
       "   'actual': 428,\n",
       "   'precision': 0.6378504672897196,\n",
       "   'recall': 0.8029411764705883},\n",
       "  'strict': {'correct': 173,\n",
       "   'incorrect': 116,\n",
       "   'partial': 0,\n",
       "   'missed': 51,\n",
       "   'spurious': 139,\n",
       "   'possible': 340,\n",
       "   'actual': 428,\n",
       "   'precision': 0.40420560747663553,\n",
       "   'recall': 0.5088235294117647},\n",
       "  'exact': {'correct': 257,\n",
       "   'incorrect': 32,\n",
       "   'partial': 0,\n",
       "   'missed': 51,\n",
       "   'spurious': 139,\n",
       "   'possible': 340,\n",
       "   'actual': 428,\n",
       "   'precision': 0.6004672897196262,\n",
       "   'recall': 0.7558823529411764}},\n",
       " 'ORG': {'ent_type': {'correct': 1154,\n",
       "   'incorrect': 187,\n",
       "   'partial': 0,\n",
       "   'missed': 59,\n",
       "   'spurious': 139,\n",
       "   'possible': 1400,\n",
       "   'actual': 1480,\n",
       "   'precision': 0.7797297297297298,\n",
       "   'recall': 0.8242857142857143},\n",
       "  'partial': {'correct': 1294,\n",
       "   'incorrect': 0,\n",
       "   'partial': 47,\n",
       "   'missed': 59,\n",
       "   'spurious': 139,\n",
       "   'possible': 1400,\n",
       "   'actual': 1480,\n",
       "   'precision': 0.8902027027027027,\n",
       "   'recall': 0.9410714285714286},\n",
       "  'strict': {'correct': 1120,\n",
       "   'incorrect': 221,\n",
       "   'partial': 0,\n",
       "   'missed': 59,\n",
       "   'spurious': 139,\n",
       "   'possible': 1400,\n",
       "   'actual': 1480,\n",
       "   'precision': 0.7567567567567568,\n",
       "   'recall': 0.8},\n",
       "  'exact': {'correct': 1294,\n",
       "   'incorrect': 47,\n",
       "   'partial': 0,\n",
       "   'missed': 59,\n",
       "   'spurious': 139,\n",
       "   'possible': 1400,\n",
       "   'actual': 1480,\n",
       "   'precision': 0.8743243243243243,\n",
       "   'recall': 0.9242857142857143}}}"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "evaluation_agg_entities_type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ner_evaluation.ner_eval import Evaluator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluator = Evaluator(test_sents_labels, y_pred, ['LOC', 'MISC', 'PER', 'ORG'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2019-03-12 12:00:31 root INFO: Imported 1517 predictions for 1517 true examples\n"
     ]
    }
   ],
   "source": [
    "results, results_agg = evaluator.evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'ent_type': {'correct': 2860,\n",
       "  'incorrect': 523,\n",
       "  'partial': 0,\n",
       "  'missed': 176,\n",
       "  'spurious': 139,\n",
       "  'possible': 3559,\n",
       "  'actual': 3522,\n",
       "  'precision': 0.8120386144236229,\n",
       "  'recall': 0.8035965158752458},\n",
       " 'partial': {'correct': 3278,\n",
       "  'incorrect': 0,\n",
       "  'partial': 105,\n",
       "  'missed': 176,\n",
       "  'spurious': 139,\n",
       "  'possible': 3559,\n",
       "  'actual': 3522,\n",
       "  'precision': 0.9456274843838728,\n",
       "  'recall': 0.9357965720708064},\n",
       " 'strict': {'correct': 2783,\n",
       "  'incorrect': 600,\n",
       "  'partial': 0,\n",
       "  'missed': 176,\n",
       "  'spurious': 139,\n",
       "  'possible': 3559,\n",
       "  'actual': 3522,\n",
       "  'precision': 0.7901760363429869,\n",
       "  'recall': 0.78196122506322},\n",
       " 'exact': {'correct': 3278,\n",
       "  'incorrect': 105,\n",
       "  'partial': 0,\n",
       "  'missed': 176,\n",
       "  'spurious': 139,\n",
       "  'possible': 3559,\n",
       "  'actual': 3522,\n",
       "  'precision': 0.9307211811470755,\n",
       "  'recall': 0.9210452374262433}}"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'LOC': {'ent_type': {'correct': 855,\n",
       "   'incorrect': 180,\n",
       "   'partial': 0,\n",
       "   'missed': 49,\n",
       "   'spurious': 139,\n",
       "   'possible': 1084,\n",
       "   'actual': 1174,\n",
       "   'precision': 0.7282793867120954,\n",
       "   'recall': 0.7887453874538746},\n",
       "  'partial': {'correct': 1016,\n",
       "   'incorrect': 0,\n",
       "   'partial': 19,\n",
       "   'missed': 49,\n",
       "   'spurious': 139,\n",
       "   'possible': 1084,\n",
       "   'actual': 1174,\n",
       "   'precision': 0.8735093696763203,\n",
       "   'recall': 0.9460332103321033},\n",
       "  'strict': {'correct': 844,\n",
       "   'incorrect': 191,\n",
       "   'partial': 0,\n",
       "   'missed': 49,\n",
       "   'spurious': 139,\n",
       "   'possible': 1084,\n",
       "   'actual': 1174,\n",
       "   'precision': 0.7189097103918228,\n",
       "   'recall': 0.7785977859778598},\n",
       "  'exact': {'correct': 1016,\n",
       "   'incorrect': 19,\n",
       "   'partial': 0,\n",
       "   'missed': 49,\n",
       "   'spurious': 139,\n",
       "   'possible': 1084,\n",
       "   'actual': 1174,\n",
       "   'precision': 0.8654173764906303,\n",
       "   'recall': 0.9372693726937269}},\n",
       " 'MISC': {'ent_type': {'correct': 200,\n",
       "   'incorrect': 89,\n",
       "   'partial': 0,\n",
       "   'missed': 51,\n",
       "   'spurious': 139,\n",
       "   'possible': 340,\n",
       "   'actual': 428,\n",
       "   'precision': 0.4672897196261682,\n",
       "   'recall': 0.5882352941176471},\n",
       "  'partial': {'correct': 257,\n",
       "   'incorrect': 0,\n",
       "   'partial': 32,\n",
       "   'missed': 51,\n",
       "   'spurious': 139,\n",
       "   'possible': 340,\n",
       "   'actual': 428,\n",
       "   'precision': 0.6378504672897196,\n",
       "   'recall': 0.8029411764705883},\n",
       "  'strict': {'correct': 173,\n",
       "   'incorrect': 116,\n",
       "   'partial': 0,\n",
       "   'missed': 51,\n",
       "   'spurious': 139,\n",
       "   'possible': 340,\n",
       "   'actual': 428,\n",
       "   'precision': 0.40420560747663553,\n",
       "   'recall': 0.5088235294117647},\n",
       "  'exact': {'correct': 257,\n",
       "   'incorrect': 32,\n",
       "   'partial': 0,\n",
       "   'missed': 51,\n",
       "   'spurious': 139,\n",
       "   'possible': 340,\n",
       "   'actual': 428,\n",
       "   'precision': 0.6004672897196262,\n",
       "   'recall': 0.7558823529411764}},\n",
       " 'PER': {'ent_type': {'correct': 651,\n",
       "   'incorrect': 67,\n",
       "   'partial': 0,\n",
       "   'missed': 17,\n",
       "   'spurious': 139,\n",
       "   'possible': 735,\n",
       "   'actual': 857,\n",
       "   'precision': 0.7596266044340724,\n",
       "   'recall': 0.8857142857142857},\n",
       "  'partial': {'correct': 711,\n",
       "   'incorrect': 0,\n",
       "   'partial': 7,\n",
       "   'missed': 17,\n",
       "   'spurious': 139,\n",
       "   'possible': 735,\n",
       "   'actual': 857,\n",
       "   'precision': 0.8337222870478413,\n",
       "   'recall': 0.972108843537415},\n",
       "  'strict': {'correct': 646,\n",
       "   'incorrect': 72,\n",
       "   'partial': 0,\n",
       "   'missed': 17,\n",
       "   'spurious': 139,\n",
       "   'possible': 735,\n",
       "   'actual': 857,\n",
       "   'precision': 0.7537922987164527,\n",
       "   'recall': 0.8789115646258503},\n",
       "  'exact': {'correct': 711,\n",
       "   'incorrect': 7,\n",
       "   'partial': 0,\n",
       "   'missed': 17,\n",
       "   'spurious': 139,\n",
       "   'possible': 735,\n",
       "   'actual': 857,\n",
       "   'precision': 0.8296382730455076,\n",
       "   'recall': 0.9673469387755103}},\n",
       " 'ORG': {'ent_type': {'correct': 1154,\n",
       "   'incorrect': 187,\n",
       "   'partial': 0,\n",
       "   'missed': 59,\n",
       "   'spurious': 139,\n",
       "   'possible': 1400,\n",
       "   'actual': 1480,\n",
       "   'precision': 0.7797297297297298,\n",
       "   'recall': 0.8242857142857143},\n",
       "  'partial': {'correct': 1294,\n",
       "   'incorrect': 0,\n",
       "   'partial': 47,\n",
       "   'missed': 59,\n",
       "   'spurious': 139,\n",
       "   'possible': 1400,\n",
       "   'actual': 1480,\n",
       "   'precision': 0.8902027027027027,\n",
       "   'recall': 0.9410714285714286},\n",
       "  'strict': {'correct': 1120,\n",
       "   'incorrect': 221,\n",
       "   'partial': 0,\n",
       "   'missed': 59,\n",
       "   'spurious': 139,\n",
       "   'possible': 1400,\n",
       "   'actual': 1480,\n",
       "   'precision': 0.7567567567567568,\n",
       "   'recall': 0.8},\n",
       "  'exact': {'correct': 1294,\n",
       "   'incorrect': 47,\n",
       "   'partial': 0,\n",
       "   'missed': 59,\n",
       "   'spurious': 139,\n",
       "   'possible': 1400,\n",
       "   'actual': 1480,\n",
       "   'precision': 0.8743243243243243,\n",
       "   'recall': 0.9242857142857143}}}"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_agg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "NER-Evaluation",
   "language": "python",
   "name": "ner-evaluation"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
